{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "## 学习方法：\n",
    "- 1. 边用边学\n",
    "- 2. 直接上案例，先跑起来\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: torch in /Users/aiquantong/anaconda3/envs/keras/lib/python3.9/site-packages (2.2.2)\n",
      "Requirement already satisfied: torchvision in /Users/aiquantong/anaconda3/envs/keras/lib/python3.9/site-packages (0.17.2)\n",
      "Requirement already satisfied: filelock in /Users/aiquantong/anaconda3/envs/keras/lib/python3.9/site-packages (from torch) (3.19.1)\n",
      "Requirement already satisfied: typing-extensions>=4.8.0 in /Users/aiquantong/anaconda3/envs/keras/lib/python3.9/site-packages (from torch) (4.15.0)\n",
      "Requirement already satisfied: sympy in /Users/aiquantong/anaconda3/envs/keras/lib/python3.9/site-packages (from torch) (1.14.0)\n",
      "Requirement already satisfied: networkx in /Users/aiquantong/anaconda3/envs/keras/lib/python3.9/site-packages (from torch) (3.2.1)\n",
      "Requirement already satisfied: jinja2 in /Users/aiquantong/anaconda3/envs/keras/lib/python3.9/site-packages (from torch) (3.1.6)\n",
      "Requirement already satisfied: fsspec in /Users/aiquantong/anaconda3/envs/keras/lib/python3.9/site-packages (from torch) (2025.10.0)\n",
      "Requirement already satisfied: numpy in /Users/aiquantong/anaconda3/envs/keras/lib/python3.9/site-packages (from torchvision) (2.0.2)\n",
      "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /Users/aiquantong/anaconda3/envs/keras/lib/python3.9/site-packages (from torchvision) (11.3.0)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /Users/aiquantong/anaconda3/envs/keras/lib/python3.9/site-packages (from jinja2->torch) (3.0.3)\n",
      "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /Users/aiquantong/anaconda3/envs/keras/lib/python3.9/site-packages (from sympy->torch) (1.3.0)\n"
     ]
    }
   ],
   "source": [
    "# Install PyTorch into the environment of the running kernel.\n",
    "# %pip targets the kernel's own interpreter; the pins match the versions\n",
    "# reported in this cell's saved output.\n",
    "%pip install torch==2.2.2 torchvision==0.17.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: numpy in /Users/aiquantong/anaconda3/envs/keras/lib/python3.9/site-packages (2.0.2)\n"
     ]
    }
   ],
   "source": [
    "# Pin NumPy to a 1.x release instead of upgrading to the latest version:\n",
    "# with torch 2.2.2, NumPy 2.0.2 made torch.tensor(ndarray) fail with\n",
    "# 'Could not infer dtype of numpy.float32'.\n",
    "%pip install numpy==1.26.4"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "## Mnist 分类任务：\n",
    "1. 网络基本构建与训练方法，常用函数解析\n",
    "2. torch.nn.functional模块\n",
    "3. nn.Module模块\n",
    "\n",
    "\n",
    "### 读取Mnist数据集\n",
    "* 会自动进行下载\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.2.2\n"
     ]
    }
   ],
   "source": [
    "import torch \n",
    "print(torch.__version__)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting matplotlib\n",
      "  Downloading matplotlib-3.9.4-cp39-cp39-macosx_10_12_x86_64.whl.metadata (11 kB)\n",
      "Collecting contourpy>=1.0.1 (from matplotlib)\n",
      "  Downloading contourpy-1.3.0-cp39-cp39-macosx_10_9_x86_64.whl.metadata (5.4 kB)\n",
      "Collecting cycler>=0.10 (from matplotlib)\n",
      "  Downloading cycler-0.12.1-py3-none-any.whl.metadata (3.8 kB)\n",
      "Collecting fonttools>=4.22.0 (from matplotlib)\n",
      "  Downloading fonttools-4.60.1-cp39-cp39-macosx_10_9_x86_64.whl.metadata (112 kB)\n",
      "Collecting kiwisolver>=1.3.1 (from matplotlib)\n",
      "  Downloading kiwisolver-1.4.7-cp39-cp39-macosx_10_9_x86_64.whl.metadata (6.3 kB)\n",
      "Requirement already satisfied: numpy>=1.23 in /Users/aiquantong/anaconda3/envs/keras/lib/python3.9/site-packages (from matplotlib) (2.0.2)\n",
      "Requirement already satisfied: packaging>=20.0 in /Users/aiquantong/anaconda3/envs/keras/lib/python3.9/site-packages (from matplotlib) (25.0)\n",
      "Requirement already satisfied: pillow>=8 in /Users/aiquantong/anaconda3/envs/keras/lib/python3.9/site-packages (from matplotlib) (11.3.0)\n",
      "Collecting pyparsing>=2.3.1 (from matplotlib)\n",
      "  Downloading pyparsing-3.2.5-py3-none-any.whl.metadata (5.0 kB)\n",
      "Requirement already satisfied: python-dateutil>=2.7 in /Users/aiquantong/anaconda3/envs/keras/lib/python3.9/site-packages (from matplotlib) (2.9.0.post0)\n",
      "Collecting importlib-resources>=3.2.0 (from matplotlib)\n",
      "  Downloading importlib_resources-6.5.2-py3-none-any.whl.metadata (3.9 kB)\n",
      "Requirement already satisfied: zipp>=3.1.0 in /Users/aiquantong/anaconda3/envs/keras/lib/python3.9/site-packages (from importlib-resources>=3.2.0->matplotlib) (3.23.0)\n",
      "Requirement already satisfied: six>=1.5 in /Users/aiquantong/anaconda3/envs/keras/lib/python3.9/site-packages (from python-dateutil>=2.7->matplotlib) (1.17.0)\n",
      "Downloading matplotlib-3.9.4-cp39-cp39-macosx_10_12_x86_64.whl (7.9 MB)\n",
      "\u001b[2K   \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.9/7.9 MB\u001b[0m \u001b[31m4.2 MB/s\u001b[0m  \u001b[33m0:00:01\u001b[0mm \u001b[31m7.1 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\n",
      "\u001b[?25hDownloading contourpy-1.3.0-cp39-cp39-macosx_10_9_x86_64.whl (265 kB)\n",
      "Downloading cycler-0.12.1-py3-none-any.whl (8.3 kB)\n",
      "Downloading fonttools-4.60.1-cp39-cp39-macosx_10_9_x86_64.whl (2.3 MB)\n",
      "\u001b[2K   \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.3/2.3 MB\u001b[0m \u001b[31m12.2 MB/s\u001b[0m  \u001b[33m0:00:00\u001b[0m \u001b[31m12.7 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\n",
      "\u001b[?25hDownloading importlib_resources-6.5.2-py3-none-any.whl (37 kB)\n",
      "Downloading kiwisolver-1.4.7-cp39-cp39-macosx_10_9_x86_64.whl (65 kB)\n",
      "Downloading pyparsing-3.2.5-py3-none-any.whl (113 kB)\n",
      "Installing collected packages: pyparsing, kiwisolver, importlib-resources, fonttools, cycler, contourpy, matplotlib\n",
      "\u001b[2K   \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7/7\u001b[0m [matplotlib]\u001b[0m \u001b[32m6/7\u001b[0m [matplotlib]ourpy]es]\n",
      "\u001b[1A\u001b[2KSuccessfully installed contourpy-1.3.0 cycler-0.12.1 fonttools-4.60.1 importlib-resources-6.5.2 kiwisolver-1.4.7 matplotlib-3.9.4 pyparsing-3.2.5\n"
     ]
    }
   ],
   "source": [
    "# Install matplotlib into the kernel's environment (pinned to the version\n",
    "# shown in this cell's saved output).\n",
    "%pip install matplotlib==3.9.4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render matplotlib figures inline in the notebook output\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "over\n"
     ]
    }
   ],
   "source": [
    "# Download the pickled MNIST archive once and cache it under data/mnist.\n",
    "# New location of the file: https://github.com/MichalDanielDobrzanski/DeepLearningPython/blob/master/mnist.pkl.gz\n",
    "\n",
    "from pathlib import Path\n",
    "import os\n",
    "import requests\n",
    "\n",
    "DATA_PATH = Path(\"data\")\n",
    "PATH = DATA_PATH / 'mnist'\n",
    "\n",
    "PATH.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "URL = \"http://deeplearning.net/data/mnist/\"\n",
    "FILENAME=\"mnist.pkl.gz\"\n",
    "\n",
    "# Only hit the network when the archive is not cached yet.\n",
    "if not (PATH / FILENAME).exists():\n",
    "    req = requests.get(URL + FILENAME, timeout=60)\n",
    "    # Fail loudly on an HTTP error instead of caching an error page as the dataset.\n",
    "    req.raise_for_status()\n",
    "    content = req.content\n",
    "    # Report only the size; printing the raw bytes would flood the cell output.\n",
    "    print(\"content\", len(content), \"bytes\")\n",
    "    # write_bytes opens, writes and closes the file, so no handle is leaked.\n",
    "    (PATH / FILENAME).write_bytes(content)\n",
    "\n",
    "print(\"over\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data/mnist/mnist.pkl.gz\n",
      "over\n"
     ]
    }
   ],
   "source": [
    "# Decompress and unpickle the dataset\n",
    "import pickle\n",
    "import gzip\n",
    "\n",
    "print((PATH/FILENAME).as_posix())\n",
    "# NOTE(review): pickle.load can execute arbitrary code; only load archives from a trusted source.\n",
    "with gzip.open((PATH / FILENAME).as_posix(), 'rb') as f:\n",
    "    # The first two entries are unpacked as (inputs, labels) pairs; the third entry is discarded.\n",
    "    ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')\n",
    "# These two bare expressions are not the cell's last statement, so nothing is displayed.\n",
    "x_train[0]\n",
    "y_train[0:10]\n",
    "print(\"over\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(50000, 784)\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8ekN5oAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAbe0lEQVR4nO3df2xV9f3H8dflR6+I7e1KbW8rPyygsIlgxqDrVMRRKd1G5McWdS7BzWhwrRGYuNRM0W2uDqczbEz5Y4GxCSjJgEEWNi22ZLNgQBgxbg0l3VpGWyZb7y2FFmw/3z+I98uVFjyXe/u+vTwfySeh955378fjtU9vezn1OeecAADoZ4OsNwAAuDIRIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYGKI9QY+qaenR8eOHVN6erp8Pp/1dgAAHjnn1N7ervz8fA0a1PfrnKQL0LFjxzRq1CjrbQAALlNTU5NGjhzZ5/1J9y249PR06y0AAOLgUl/PExag1atX6/rrr9dVV12lwsJCvfvuu59qjm+7AUBquNTX84QE6PXXX9eyZcu0YsUKvffee5oyZYpKSkp0/PjxRDwcAGAgcgkwffp0V1ZWFvm4u7vb5efnu8rKykvOhkIhJ4nFYrFYA3yFQqGLfr2P+yugM2fOaP/+/SouLo7cNmjQIBUXF6u2tvaC47u6uhQOh6MWACD1xT1AH374obq7u5Wbmxt1e25urlpaWi44vrKyUoFAILJ4BxwAXBnM3wVXUVGhUCgUWU1NTdZbAgD0g7j/PaDs7GwNHjxYra2tUbe3trYqGAxecLzf75ff74/3NgAASS7ur4DS0tI0depUVVVVRW7r6elRVVWVioqK4v1wAIABKiFXQli2bJkWLVqkL3zhC5o+fbpefvlldXR06Nvf/nYiHg4AMAAlJED33HOP/vOf/+jpp59WS0uLbrnlFu3cufOCNyYAAK5cPuecs97E+cLhsAKBgPU2AACXKRQKKSMjo8/7zd8FBwC4MhEgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmhlhvAEgmgwcP9jwTCAQSsJP4KC8vj2nu6quv9jwzYcIEzzNlZWWeZ372s595nrnvvvs8z0hSZ2en55nnn3/e88yzzz7reSYV8AoIAGCCAAEATMQ9QM8884x8Pl/UmjhxYrwfBgAwwCXkZ0A33XST3nrrrf9/kCH8qAkAEC0hZRgyZIiCwWAiPjUAIEUk5GdAhw8fVn5+vsaOHav7779fjY2NfR7b1dWlcDgctQAAqS/uASosLNS6deu0c+dOvfLKK2poaNDtt9+u9vb2Xo+vrKxUIBCIrFGjRsV7SwCAJBT3AJWWluob3/iGJk+erJKSEv3xj39UW1ub3njjjV6Pr6ioUCgUiqympqZ4bwkAkIQS/u6AzMxM3Xjjjaqvr+/1fr/fL7/fn+htAACSTML/HtDJkyd15MgR5eXlJfqhAAADSNwD9Pjjj6umpkb//Oc/9c4772j+/PkaPHhwzJfCAACkprh/C+7o0aO67777dOLECV177bW67bbbtGfPHl177bXxfigAwAAW9wBt2rQp3p8SSWr06NGeZ9LS0jzPfOlLX/I8c9ttt3mekc79zNKrhQsXxvRYqebo0aOeZ1atWuV5Zv78+Z5n+noX7qX87W9/8zxTU1MT02NdibgWHADABAECAJgg
QAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgwuecc9abOF84HFYgELDexhXllltuiWlu165dnmf4dzsw9PT0eJ75zne+43nm5MmTnmdi0dzcHNPc//73P88zdXV1MT1WKgqFQsrIyOjzfl4BAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwMQQ6w3AXmNjY0xzJ06c8DzD1bDP2bt3r+eZtrY2zzN33nmn5xlJOnPmjOeZ3/72tzE9Fq5cvAICAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAExwMVLov//9b0xzy5cv9zzzta99zfPMgQMHPM+sWrXK80ysDh486Hnmrrvu8jzT0dHheeamm27yPCNJjz32WExzgBe8AgIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATPicc856E+cLh8MKBALW20CCZGRkeJ5pb2/3PLNmzRrPM5L04IMPep751re+5Xlm48aNnmeAgSYUCl30v3leAQEATBAgAIAJzwHavXu35s6dq/z8fPl8Pm3dujXqfuecnn76aeXl5WnYsGEqLi7W4cOH47VfAECK8Bygjo4OTZkyRatXr+71/pUrV2rVqlV69dVXtXfvXg0fPlwlJSXq7Oy87M0CAFKH59+IWlpaqtLS0l7vc87p5Zdf1g9+8APdfffdkqT169crNzdXW7du1b333nt5uwUApIy4/gyooaFBLS0tKi4ujtwWCARUWFio2traXme6uroUDoejFgAg9cU1QC0tLZKk3NzcqNtzc3Mj931SZWWlAoFAZI0aNSqeWwIAJCnzd8FVVFQoFApFVlNTk/WWAAD9IK4BCgaDkqTW1tao21tbWyP3fZLf71dGRkbUAgCkvrgGqKCgQMFgUFVVVZHbwuGw9u7dq6Kiong+FABggPP8LriTJ0+qvr4+8nFDQ4MOHjyorKwsjR49WkuWLNGPf/xj3XDDDSooKNBTTz2l/Px8zZs3L577BgAMcJ4DtG/fPt15552Rj5ctWyZJWrRokdatW6cnnnhCHR0devjhh9XW1qbbbrtNO3fu1FVXXRW/XQMABjwuRoqU9MILL8Q09/H/UHlRU1Pjeeb8v6rwafX09HieASxxMVIAQFIiQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACa6GjZQ0fPjwmOa2b9/ueeaOO+7wPFNaWup55s9//rPnGcASV8MGACQlAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEFyMFzjNu3DjPM++9957nmba2Ns8zb7/9tueZffv2eZ6RpNWrV3ueSbIvJUgCXIwUAJCUCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATXIwUuEzz58/3PLN27VrPM+np6Z5nYvXkk096nlm/fr3nmebmZs8zGDi4GCkAICkRIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACa4GClgYNKkSZ5nXnrpJc8zs2bN8jwTqzVr1nieee655zzP/Pvf//Y8AxtcjBQAkJQIEADAhOcA7d69W3PnzlV+fr58Pp+2bt0adf8DDzwgn88XtebMmROv/QIAUoTnAHV0dGjKlClavXp1n8fMmTNHzc3NkbVx48bL2iQAIPUM8TpQWlqq0tLSix7j9/sVDAZj3hQAIPUl5GdA1dXVysnJ0YQJE/TII4/oxIkTfR7b1dWlcDgctQAAqS/uAZozZ47Wr1+vqqoq/fSnP1VNTY1K
S0vV3d3d6/GVlZUKBAKRNWrUqHhvCQCQhDx/C+5S7r333sifb775Zk2ePFnjxo1TdXV1r38noaKiQsuWLYt8HA6HiRAAXAES/jbssWPHKjs7W/X19b3e7/f7lZGREbUAAKkv4QE6evSoTpw4oby8vEQ/FABgAPH8LbiTJ09GvZppaGjQwYMHlZWVpaysLD377LNauHChgsGgjhw5oieeeELjx49XSUlJXDcOABjYPAdo3759uvPOOyMff/zzm0WLFumVV17RoUOH9Jvf/EZtbW3Kz8/X7Nmz9aMf/Uh+vz9+uwYADHhcjBQYIDIzMz3PzJ07N6bHWrt2recZn8/neWbXrl2eZ+666y7PM7DBxUgBAEmJAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJrgaNoALdHV1eZ4ZMsTzb3fRRx995Hkmlt8tVl1d7XkGl4+rYQMAkhIBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYML71QMBXLbJkyd7nvn617/ueWbatGmeZ6TYLiwaiw8++MDzzO7duxOwE1jgFRAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIKLkQLnmTBhgueZ8vJyzzMLFizwPBMMBj3P9Kfu7m7PM83NzZ5nenp6PM8gOfEKCABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwwcVIkfRiuQjnfffdF9NjxXJh0euvvz6mx0pm+/bt8zzz3HPPeZ75wx/+4HkGqYNXQAAAEwQIAGDCU4AqKys1bdo0paenKycnR/PmzVNdXV3UMZ2dnSorK9OIESN0zTXXaOHChWptbY3rpgEAA5+nANXU1KisrEx79uzRm2++qbNnz2r27Nnq6OiIHLN06VJt375dmzdvVk1NjY4dOxbTL98CAKQ2T29C2LlzZ9TH69atU05Ojvbv368ZM2YoFArp17/+tTZs2KAvf/nLkqS1a9fqs5/9rPbs2aMvfvGL8ds5AGBAu6yfAYVCIUlSVlaWJGn//v06e/asiouLI8dMnDhRo0ePVm1tba+fo6urS+FwOGoBAFJfzAHq6enRkiVLdOutt2rSpEmSpJaWFqWlpSkzMzPq2NzcXLW0tPT6eSorKxUIBCJr1KhRsW4JADCAxBygsrIyvf/++9q0adNlbaCiokKhUCiympqaLuvzAQAGhpj+Imp5ebl27Nih3bt3a+TIkZHbg8Ggzpw5o7a2tqhXQa2trX3+ZUK/3y+/3x/LNgAAA5inV0DOOZWXl2vLli3atWuXCgoKou6fOnWqhg4dqqqqqshtdXV1amxsVFFRUXx2DABICZ5eAZWVlWnDhg3atm2b0tPTIz/XCQQCGjZsmAKBgB588EEtW7ZMWVlZysjI0KOPPqqioiLeAQcAiOIpQK+88ookaebMmVG3r127Vg888IAk6ec//7kGDRqkhQsXqqurSyUlJfrVr34Vl80CAFKHzznnrDdxvnA4rEAgYL0NfAq5ubmeZz73uc95nvnlL3/peWbixImeZ5Ld3r17Pc+88MILMT3Wtm3bPM/09PTE9FhIXaFQSBkZGX3ez7XgAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYCKm34iK5JWVleV5Zs2aNTE91i233OJ5ZuzYsTE9VjJ75513PM+8+OKLnmf+9Kc/eZ45ffq05xmgv/AKCABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwwcVI+0lhYaHnmeXLl3uemT59uueZ6667zvNMsjt16lRMc6tWrfI885Of/MTzTEdH
h+cZINXwCggAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMMHFSPvJ/Pnz+2WmP33wwQeeZ3bs2OF55qOPPvI88+KLL3qekaS2traY5gB4xysgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMCEzznnrDdxvnA4rEAgYL0NAMBlCoVCysjI6PN+XgEBAEwQIACACU8Bqqys1LRp05Senq6cnBzNmzdPdXV1UcfMnDlTPp8vai1evDiumwYADHyeAlRTU6OysjLt2bNHb775ps6ePavZs2ero6Mj6riHHnpIzc3NkbVy5cq4bhoAMPB5+o2oO3fujPp43bp1ysnJ0f79+zVjxozI7VdffbWCwWB8dggASEmX9TOgUCgkScrKyoq6/bXXXlN2drYmTZqkiooKnTp1qs/P0dXVpXA4HLUAAFcAF6Pu7m731a9+1d16661Rt69Zs8bt3LnTHTp0yP3ud79z1113nZs/f36fn2fFihVOEovFYrFSbIVCoYt2JOYALV682I0ZM8Y1NTVd9LiqqionydXX1/d6f2dnpwuFQpHV1NRkftJYLBaLdfnrUgHy9DOgj5WXl2vHjh3avXu3Ro4cedFjCwsLJUn19fUaN27cBff7/X75/f5YtgEAGMA8Bcg5p0cffVRbtmxRdXW1CgoKLjlz8OBBSVJeXl5MGwQApCZPASorK9OGDRu0bds2paenq6WlRZIUCAQ0bNgwHTlyRBs2bNBXvvIVjRgxQocOHdLSpUs1Y8YMTZ48OSH/AACAAcrLz33Ux/f51q5d65xzrrGx0c2YMcNlZWU5v9/vxo8f75YvX37J7wOeLxQKmX/fksVisViXvy71tZ+LkQIAEoKLkQIAkhIBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwETSBcg5Z70FAEAcXOrredIFqL293XoLAIA4uNTXc59LspccPT09OnbsmNLT0+Xz+aLuC4fDGjVqlJqampSRkWG0Q3uch3M4D+dwHs7hPJyTDOfBOaf29nbl5+dr0KC+X+cM6cc9fSqDBg3SyJEjL3pMRkbGFf0E+xjn4RzOwzmch3M4D+dYn4dAIHDJY5LuW3AAgCsDAQIAmBhQAfL7/VqxYoX8fr/1VkxxHs7hPJzDeTiH83DOQDoPSfcmBADAlWFAvQICAKQOAgQAMEGAAAAmCBAAwMSACdDq1at1/fXX66qrrlJhYaHeffdd6y31u2eeeUY+ny9qTZw40XpbCbd7927NnTtX+fn58vl82rp1a9T9zjk9/fTTysvL07Bhw1RcXKzDhw/bbDaBLnUeHnjggQueH3PmzLHZbIJUVlZq2rRpSk9PV05OjubNm6e6urqoYzo7O1VWVqYRI0bommuu0cKFC9Xa2mq048T4NOdh5syZFzwfFi9ebLTj3g2IAL3++utatmyZVqxYoffee09TpkxRSUmJjh8/br21fnfTTTepubk5sv7yl79YbynhOjo6NGXKFK1evbrX+1euXKlVq1bp1Vdf1d69ezV8+HCVlJSos7Ozn3eaWJc6D5I0Z86cqOfHxo0b+3GHiVdTU6OysjLt2bNHb775ps6ePavZs2ero6MjcszSpUu1fft2bd68WTU1NTp27JgWLFhguOv4+zTnQZIeeuihqOfDypUrjXbcBzcATJ8+3ZWVlUU+7u7udvn5+a6ystJwV/1vxYoVbsqUKdbbMCXJbdmyJfJxT0+PCwaD7oUXXojc1tbW5vx+v9u4caPBDvvHJ8+Dc84tWrTI3X333Sb7sXL8+HEnydXU1Djnzv27Hzp0qNu8eXPkmL///e9OkqutrbXa
ZsJ98jw459wdd9zhHnvsMbtNfQpJ/wrozJkz2r9/v4qLiyO3DRo0SMXFxaqtrTXcmY3Dhw8rPz9fY8eO1f3336/GxkbrLZlqaGhQS0tL1PMjEAiosLDwinx+VFdXKycnRxMmTNAjjzyiEydOWG8poUKhkCQpKytLkrR//36dPXs26vkwceJEjR49OqWfD588Dx977bXXlJ2drUmTJqmiokKnTp2y2F6fku5ipJ/04Ycfqru7W7m5uVG35+bm6h//+IfRrmwUFhZq3bp1mjBhgpqbm/Xss8/q9ttv1/vvv6/09HTr7ZloaWmRpF6fHx/fd6WYM2eOFixYoIKCAh05ckRPPvmkSktLVVtbq8GDB1tvL+56enq0ZMkS3XrrrZo0aZKkc8+HtLQ0ZWZmRh2bys+H3s6DJH3zm9/UmDFjlJ+fr0OHDun73/++6urq9Pvf/95wt9GSPkD4f6WlpZE/T548WYWFhRozZozeeOMNPfjgg4Y7QzK49957I3+++eabNXnyZI0bN07V1dWaNWuW4c4So6ysTO+///4V8XPQi+nrPDz88MORP998883Ky8vTrFmzdOTIEY0bN66/t9mrpP8WXHZ2tgYPHnzBu1haW1sVDAaNdpUcMjMzdeONN6q+vt56K2Y+fg7w/LjQ2LFjlZ2dnZLPj/Lycu3YsUNvv/121K9vCQaDOnPmjNra2qKOT9XnQ1/noTeFhYWSlFTPh6QPUFpamqZOnaqqqqrIbT09PaqqqlJRUZHhzuydPHlSR44cUV5envVWzBQUFCgYDEY9P8LhsPbu3XvFPz+OHj2qEydOpNTzwzmn8vJybdmyRbt27VJBQUHU/VOnTtXQoUOjng91dXVqbGxMqefDpc5Dbw4ePChJyfV8sH4XxKexadMm5/f73bp169wHH3zgHn74YZeZmelaWlqst9avvve977nq6mrX0NDg/vrXv7ri4mKXnZ3tjh8/br21hGpvb3cHDhxwBw4ccJLcSy+95A4cOOD+9a9/Oeece/75511mZqbbtm2bO3TokLv77rtdQUGBO336tPHO4+ti56G9vd09/vjjrra21jU0NLi33nrLff7zn3c33HCD6+zstN563DzyyCMuEAi46upq19zcHFmnTp2KHLN48WI3evRot2vXLrdv3z5XVFTkioqKDHcdf5c6D/X19e6HP/yh27dvn2toaHDbtm1zY8eOdTNmzDDeebQBESDnnPvFL37hRo8e7dLS0tz06dPdnj17rLfU7+655x6Xl5fn0tLS3HXXXefuueceV19fb72thHv77bedpAvWokWLnHPn3or91FNPudzcXOf3+92sWbNcXV2d7aYT4GLn4dSpU2727Nnu2muvdUOHDnVjxoxxDz30UMr9T1pv//yS3Nq1ayPHnD592n33u991n/nMZ9zVV1/t5s+f75qbm+02nQCXOg+NjY1uxowZLisry/n9fjd+/Hi3fPlyFwqFbDf+Cfw6BgCAiaT/GRAAIDURIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACb+Dwuo74MxItlsAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Display one sample of the data\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "plt.imshow(x_train[0].reshape((28, 28)), cmap='gray')\n",
    "print(x_train.shape)\n",
    "\n",
    "\n",
    "# reshape turns the one-dimensional array into two-dimensional data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting numpy==1.26.4\n",
      "  Downloading numpy-1.26.4-cp39-cp39-macosx_10_9_x86_64.whl.metadata (61 kB)\n",
      "Downloading numpy-1.26.4-cp39-cp39-macosx_10_9_x86_64.whl (20.6 MB)\n",
      "\u001b[2K   \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m20.6/20.6 MB\u001b[0m \u001b[31m11.0 MB/s\u001b[0m  \u001b[33m0:00:01\u001b[0m1.7 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m:01\u001b[0m\n",
      "\u001b[?25hInstalling collected packages: numpy\n",
      "  Attempting uninstall: numpy\n",
      "    Found existing installation: numpy 2.0.2\n",
      "    Uninstalling numpy-2.0.2:\n",
      "      Successfully uninstalled numpy-2.0.2\n",
      "Successfully installed numpy-1.26.4\n"
     ]
    }
   ],
   "source": [
    "# Downgrade NumPy in the kernel's own environment (%pip targets the running kernel).\n",
    "%pip install numpy==1.26.4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        ...,\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.]]) tensor([5, 0, 4,  ..., 8, 4, 8])\n",
      "-------------\n",
      "torch.Size([50000, 784])\n",
      "-------------\n",
      "torch.Size([50000])\n",
      "-------------\n",
      "tensor(0.) tensor(9)\n"
     ]
    }
   ],
   "source": [
    "# Preprocessing\n",
    "# torch works on tensors (which can live on a GPU); numpy works on ndarrays (CPU)\n",
    "\n",
    "# pip install numpy==1.26.4  (downgrade) -- with the newer NumPy the conversion failed with:\n",
    "# Could not infer dtype of numpy.float32\n",
    "\n",
    "\n",
    "import torch\n",
    "\n",
    "# Convert the four ndarrays to tensors in one pass\n",
    "x_train, y_train, x_valid, y_valid = map(\n",
    "    torch.tensor, (x_train, y_train, x_valid, y_valid)\n",
    ")\n",
    "\n",
    "n, c = x_train.shape\n",
    "print(x_train, y_train)\n",
    "print(\"-------------\")\n",
    "print(x_train.shape)\n",
    "print(\"-------------\")\n",
    "print(y_train.shape)\n",
    "print(\"-------------\")\n",
    "# Range of the class labels: both bounds are taken from y_train\n",
    "print(y_train.min(), y_train.max())\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# torch说明\n",
    "\n",
    "torch.nn.functional  和 nn.Module\n",
    "\n",
    "模型有可学习参数时用 nn.Module，\n",
    "否则用 nn.functional\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "# Cross-entropy loss taken directly from torch.nn.functional\n",
    "loss_func = F.cross_entropy\n",
    "\n",
    "def model(xb):\n",
    "    \"\"\"Linear map of a batch: xb times the global `weights`, plus the global `bias`.\"\"\"\n",
    "    projected = xb.mm(weights)\n",
    "    return projected + bias\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(10.9519, grad_fn=<NllLossBackward0>)\n"
     ]
    }
   ],
   "source": [
    "# Batch size\n",
    "bs = 64 \n",
    "# First batch: the leading bs samples and their labels\n",
    "xb = x_train[0:bs]\n",
    "yb = y_train[0:bs]\n",
    "# Randomly initialised 784x10 weight matrix and a zero bias, both tracked by autograd\n",
    "weights = torch.randn([784, 10], dtype = torch.float, requires_grad = True)\n",
    "bias = torch.zeros(10, requires_grad = True)\n",
    "# Loss of the untrained linear model on that batch\n",
    "print(loss_func(model(xb), yb))\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 创建一个model来简化代码\n",
    "\n",
    "FC = wx+b\n",
    "\n",
    "1. 继承nn.Module, 需要调用nn.Module的构造函数\n",
    "2. 无需反向传播函数，nn.Module可以利用autograd自动实现\n",
    "3. Module中的可学习参数可以通过named_parameters()或者parameters()返回迭代器\n",
    "\n",
    "\n",
    "# drop-out \n",
    "1. 防止过拟合：训练时随机丢弃一部分神经元\n",
    "2. 全连接层一般都是需要调用dropout\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "\n",
    "class Mnist_NN(nn.Module):\n",
    "    \"\"\"Fully connected classifier with layer sizes 784 -> 128 -> 256 -> 10.\n",
    "\n",
    "    Dropout (p=0.5) is applied after the first hidden activation only.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.hidden1 = nn.Linear(in_features=784, out_features=128)\n",
    "        self.hidden2 = nn.Linear(in_features=128, out_features=256)\n",
    "        self.out = nn.Linear(in_features=256, out_features=10)\n",
    "        self.dropout = nn.Dropout(p=0.5)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the raw class scores for a batch with 784 features per row.\"\"\"\n",
    "        first = self.dropout(F.relu(self.hidden1(x)))\n",
    "        second = F.relu(self.hidden2(first))\n",
    "        return self.out(second)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mnist_NN(\n",
      "  (hidden1): Linear(in_features=784, out_features=128, bias=True)\n",
      "  (hidden2): Linear(in_features=128, out_features=256, bias=True)\n",
      "  (out): Linear(in_features=256, out_features=10, bias=True)\n",
      "  (dropout): Dropout(p=0.5, inplace=False)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Instantiate the network and print its layer summary\n",
    "net = Mnist_NN()\n",
    "print(net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "hidden1.weight Parameter containing:\n",
      "tensor([[-0.0025, -0.0180,  0.0190,  ...,  0.0347, -0.0221,  0.0016],\n",
      "        [-0.0022, -0.0042, -0.0213,  ...,  0.0061,  0.0283, -0.0344],\n",
      "        [ 0.0164,  0.0127, -0.0126,  ..., -0.0033, -0.0120, -0.0044],\n",
      "        ...,\n",
      "        [-0.0010,  0.0037,  0.0085,  ...,  0.0082,  0.0162,  0.0149],\n",
      "        [-0.0188, -0.0240, -0.0302,  ..., -0.0141,  0.0140,  0.0271],\n",
      "        [-0.0303, -0.0023, -0.0161,  ..., -0.0174,  0.0228, -0.0291]],\n",
      "       requires_grad=True) torch.Size([128, 784])\n",
      "hidden1.bias Parameter containing:\n",
      "tensor([-0.0103,  0.0243, -0.0153, -0.0245, -0.0333,  0.0314,  0.0078, -0.0087,\n",
      "         0.0118,  0.0175,  0.0261,  0.0225, -0.0274, -0.0270,  0.0054, -0.0042,\n",
      "        -0.0169, -0.0282,  0.0029,  0.0083, -0.0137,  0.0189,  0.0114, -0.0093,\n",
      "        -0.0124, -0.0023,  0.0329, -0.0030, -0.0158,  0.0181,  0.0314,  0.0051,\n",
      "        -0.0066, -0.0203, -0.0190,  0.0292,  0.0017, -0.0049,  0.0354,  0.0112,\n",
      "         0.0072, -0.0204,  0.0162, -0.0129,  0.0106,  0.0181, -0.0314, -0.0272,\n",
      "         0.0205,  0.0185,  0.0343, -0.0166, -0.0165,  0.0296,  0.0281, -0.0222,\n",
      "        -0.0158, -0.0295,  0.0256, -0.0016,  0.0221,  0.0049, -0.0068,  0.0114,\n",
      "         0.0297,  0.0083, -0.0008, -0.0172,  0.0356, -0.0201, -0.0290,  0.0103,\n",
      "        -0.0071, -0.0248,  0.0005,  0.0070,  0.0253,  0.0199, -0.0095, -0.0125,\n",
      "         0.0312,  0.0245,  0.0117,  0.0197,  0.0245, -0.0339,  0.0025,  0.0132,\n",
      "        -0.0284, -0.0217, -0.0187, -0.0245, -0.0184,  0.0138, -0.0126,  0.0277,\n",
      "        -0.0006,  0.0327,  0.0024, -0.0129,  0.0238,  0.0057, -0.0151,  0.0036,\n",
      "         0.0031,  0.0259, -0.0287, -0.0180, -0.0293, -0.0249,  0.0069, -0.0046,\n",
      "        -0.0251,  0.0147,  0.0292,  0.0063,  0.0286, -0.0174, -0.0096, -0.0263,\n",
      "        -0.0029, -0.0311,  0.0224,  0.0308,  0.0204, -0.0038,  0.0077,  0.0268],\n",
      "       requires_grad=True) torch.Size([128])\n",
      "hidden2.weight Parameter containing:\n",
      "tensor([[ 0.0815,  0.0171,  0.0091,  ..., -0.0792, -0.0195, -0.0046],\n",
      "        [ 0.0875,  0.0251, -0.0797,  ..., -0.0116,  0.0381,  0.0333],\n",
      "        [-0.0597, -0.0232, -0.0034,  ..., -0.0082, -0.0387, -0.0299],\n",
      "        ...,\n",
      "        [ 0.0661, -0.0260,  0.0829,  ..., -0.0315, -0.0273, -0.0565],\n",
      "        [-0.0345,  0.0748,  0.0305,  ..., -0.0265, -0.0139,  0.0705],\n",
      "        [ 0.0120, -0.0251, -0.0863,  ...,  0.0816,  0.0120,  0.0817]],\n",
      "       requires_grad=True) torch.Size([256, 128])\n",
      "hidden2.bias Parameter containing:\n",
      "tensor([-0.0820,  0.0004,  0.0849,  0.0133,  0.0644, -0.0649, -0.0864, -0.0768,\n",
      "        -0.0203,  0.0313, -0.0488, -0.0272, -0.0015, -0.0742, -0.0507,  0.0673,\n",
      "         0.0289, -0.0397,  0.0153,  0.0330,  0.0551,  0.0353,  0.0642, -0.0402,\n",
      "        -0.0522, -0.0734, -0.0878,  0.0257, -0.0717, -0.0546, -0.0095, -0.0879,\n",
      "         0.0452,  0.0336,  0.0843,  0.0562, -0.0388,  0.0068,  0.0543, -0.0356,\n",
      "        -0.0220, -0.0731,  0.0269,  0.0189,  0.0142, -0.0660,  0.0415,  0.0723,\n",
      "        -0.0256,  0.0214,  0.0807,  0.0386,  0.0834,  0.0834,  0.0858, -0.0371,\n",
      "        -0.0549,  0.0515, -0.0644,  0.0478,  0.0804,  0.0686,  0.0249, -0.0655,\n",
      "        -0.0168, -0.0005,  0.0535, -0.0530, -0.0033, -0.0006, -0.0011, -0.0650,\n",
      "        -0.0338, -0.0778,  0.0457,  0.0250,  0.0455, -0.0883,  0.0619,  0.0088,\n",
      "         0.0347, -0.0170, -0.0884,  0.0297,  0.0304,  0.0595, -0.0800, -0.0376,\n",
      "        -0.0682, -0.0331, -0.0432, -0.0255, -0.0193,  0.0249, -0.0832,  0.0160,\n",
      "         0.0105, -0.0796, -0.0544, -0.0795, -0.0398,  0.0575, -0.0326, -0.0721,\n",
      "         0.0239, -0.0844, -0.0608,  0.0606, -0.0116,  0.0112,  0.0213,  0.0425,\n",
      "         0.0876, -0.0732, -0.0579,  0.0209,  0.0701, -0.0818,  0.0371, -0.0156,\n",
      "         0.0666, -0.0443,  0.0484, -0.0367, -0.0075,  0.0505, -0.0832,  0.0229,\n",
      "         0.0201,  0.0524,  0.0068, -0.0819,  0.0507,  0.0447, -0.0791, -0.0488,\n",
      "        -0.0046,  0.0018, -0.0278,  0.0394,  0.0879, -0.0490,  0.0726,  0.0571,\n",
      "         0.0012, -0.0127, -0.0387, -0.0545, -0.0422, -0.0156, -0.0587,  0.0828,\n",
      "         0.0204, -0.0667, -0.0462, -0.0483, -0.0270, -0.0171,  0.0664,  0.0100,\n",
      "         0.0254,  0.0881,  0.0782,  0.0649, -0.0622,  0.0792,  0.0479, -0.0830,\n",
      "        -0.0157,  0.0758,  0.0695, -0.0786,  0.0719,  0.0859, -0.0534, -0.0280,\n",
      "        -0.0223, -0.0812, -0.0624,  0.0072,  0.0852,  0.0310,  0.0559, -0.0200,\n",
      "         0.0087,  0.0662,  0.0635, -0.0474, -0.0496, -0.0121,  0.0385, -0.0104,\n",
      "        -0.0598,  0.0195, -0.0664, -0.0094,  0.0257,  0.0498,  0.0142,  0.0123,\n",
      "        -0.0775,  0.0845, -0.0229, -0.0471, -0.0469, -0.0106,  0.0627, -0.0548,\n",
      "        -0.0808, -0.0194, -0.0573, -0.0302, -0.0030, -0.0139, -0.0805,  0.0206,\n",
      "         0.0879, -0.0703,  0.0165, -0.0210, -0.0303, -0.0454, -0.0562, -0.0493,\n",
      "         0.0695, -0.0027, -0.0191,  0.0257, -0.0882,  0.0785, -0.0072,  0.0292,\n",
      "         0.0097,  0.0112, -0.0506, -0.0313, -0.0880, -0.0859, -0.0497, -0.0389,\n",
      "         0.0613,  0.0447,  0.0883,  0.0109,  0.0510, -0.0566, -0.0554,  0.0457,\n",
      "         0.0352, -0.0850, -0.0321, -0.0531,  0.0071,  0.0412,  0.0171,  0.0576],\n",
      "       requires_grad=True) torch.Size([256])\n",
      "out.weight Parameter containing:\n",
      "tensor([[ 0.0188,  0.0071, -0.0119,  ...,  0.0160, -0.0016,  0.0169],\n",
      "        [ 0.0512,  0.0205, -0.0298,  ..., -0.0050,  0.0326, -0.0151],\n",
      "        [ 0.0484,  0.0503,  0.0546,  ..., -0.0170,  0.0415,  0.0295],\n",
      "        ...,\n",
      "        [ 0.0559,  0.0368, -0.0052,  ..., -0.0020, -0.0566,  0.0296],\n",
      "        [ 0.0093,  0.0127, -0.0069,  ...,  0.0246, -0.0312,  0.0475],\n",
      "        [ 0.0099,  0.0150,  0.0080,  ...,  0.0368,  0.0423, -0.0436]],\n",
      "       requires_grad=True) torch.Size([10, 256])\n",
      "out.bias Parameter containing:\n",
      "tensor([-0.0332,  0.0144, -0.0189, -0.0385, -0.0453,  0.0512,  0.0136,  0.0361,\n",
      "         0.0561, -0.0252], requires_grad=True) torch.Size([10])\n"
     ]
    }
   ],
   "source": [
    "# Print every learnable parameter: its name, its values and its shape\n",
    "for name, parameter in net.named_parameters():\n",
    "    print(name, parameter, parameter.size())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 使用TensorDataset和DataLoader来简化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import TensorDataset\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# Wrap the training tensors in a dataset and serve them in shuffled batches of bs\n",
    "train_ds = TensorDataset(x_train, y_train)\n",
    "train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)\n",
    "\n",
    "# DataLoader serves the dataset in batches\n",
    "# NOTE(review): the earlier note said it packs data for the GPU; nothing in this cell moves tensors to a GPU\n",
    "# shuffle: shuffles the samples\n",
    "# Later on: build your own custom test dataset\n",
    "\n",
    "# Validation data is not shuffled and uses a batch twice as large\n",
    "valid_ds = TensorDataset(x_valid, y_valid)\n",
    "valid_dl = DataLoader(valid_ds, batch_size=bs * 2)\n",
    "\n",
    "def get_data(train_ds, valid_ds, bs):\n",
    "    \"\"\"Build the (training, validation) DataLoader pair.\n",
    "\n",
    "    The training loader shuffles and uses batches of `bs`; the validation\n",
    "    loader does not shuffle and uses batches of `bs * 2`.\n",
    "    \"\"\"\n",
    "    train_loader = DataLoader(train_ds, batch_size=bs, shuffle=True)\n",
    "    valid_loader = DataLoader(valid_ds, batch_size=bs * 2)\n",
    "    return train_loader, valid_loader"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. 一般在训练模型加上model.train(),正常使用Batch Normalization和Dropout\n",
    "2. 测试的时候一般选择model.eval()，此时Dropout不再生效，Batch Normalization改用训练阶段累计的均值和方差\n",
    "\n",
    "zip的用法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from torch import optim\n",
    "\n",
    "def fit(steps, model, loss_func, opt, train_dl, valid_dl):\n",
    "    \"\"\"Train ``model`` for ``steps`` passes over ``train_dl``.\n",
    "\n",
    "    After every pass the loss on ``valid_dl`` is computed and printed as\n",
    "    ``step<N> loss: <value>``. Nothing is returned.\n",
    "    \"\"\"\n",
    "    for step_idx in range(steps):\n",
    "        # Training pass: each batch goes through loss_batch together with the optimizer.\n",
    "        model.train()\n",
    "        for xb, yb in train_dl:\n",
    "            loss_batch(model, loss_func, xb, yb, opt)\n",
    "\n",
    "        # Validation pass: no optimizer is passed and gradient tracking is off.\n",
    "        model.eval()\n",
    "        with torch.no_grad():\n",
    "            batch_stats = [loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]\n",
    "        # Split the (loss, batch_size) pairs into two parallel tuples.\n",
    "        losses, nums = zip(*batch_stats)\n",
    "\n",
    "        # Weight each batch loss by its share of all validation samples, then add them up.\n",
    "        weighted = np.multiply(losses, nums) / np.sum(nums)\n",
    "        val_loss = np.sum(weighted)\n",
    "        print(\"step\" + str(step_idx), 'loss: ' + str(val_loss))\n",
    "        \n",
    "def get_model(lr=0.001):\n",
    "    \"\"\"Create a fresh Mnist_NN together with an Adam optimizer over its parameters.\n",
    "\n",
    "    lr: learning rate handed to Adam. Defaults to 0.001, the value that used to be\n",
    "        hard-coded here, so calling get_model() with no arguments behaves as before.\n",
    "\n",
    "    Returns the pair (model, optimizer).\n",
    "    \"\"\"\n",
    "    model = Mnist_NN()\n",
    "    return model, optim.Adam(model.parameters(), lr=lr)\n",
    "        \n",
    "def loss_batch(model, loss_func, xb, yb, opt=None):\n",
    "    \"\"\"Compute the loss of ``model`` on one batch, optionally taking an optimizer step.\n",
    "\n",
    "    When ``opt`` is given, the loss is back-propagated, the optimizer is stepped and\n",
    "    its gradients are then cleared. When ``opt`` is None only the loss is computed.\n",
    "\n",
    "    Returns ``(loss_value, batch_size)`` where ``loss_value`` is ``loss.item()`` and\n",
    "    ``batch_size`` is ``len(xb)``.\n",
    "    \"\"\"\n",
    "    predictions = model(xb)\n",
    "    loss = loss_func(predictions, yb)\n",
    "\n",
    "    should_update = opt is not None\n",
    "    if should_update:\n",
    "        loss.backward()  # back-propagate\n",
    "        opt.step()  # apply the parameter update\n",
    "        opt.zero_grad()  # clear gradients so the next batch starts from zero\n",
    "\n",
    "    batch_size = len(xb)\n",
    "    return loss.item(), batch_size\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step0 loss: 0.16728072224259374\n",
      "step1 loss: 0.1299698649019003\n",
      "step2 loss: 0.10939429392553865\n",
      "step3 loss: 0.10529611498489976\n",
      "step4 loss: 0.09242153388019653\n",
      "step5 loss: 0.09359208861691877\n",
      "step6 loss: 0.0835448778225109\n",
      "step7 loss: 0.08987521542813628\n",
      "step8 loss: 0.09060409275982528\n",
      "step9 loss: 0.08350172254936769\n",
      "step10 loss: 0.08170419474693481\n",
      "step11 loss: 0.0798896457316354\n",
      "step12 loss: 0.08407680627792143\n",
      "step13 loss: 0.08100661720724311\n",
      "step14 loss: 0.07725527662937527\n",
      "step15 loss: 0.07406482595845591\n",
      "step16 loss: 0.07745582395065577\n",
      "step17 loss: 0.07868006379744037\n",
      "step18 loss: 0.07290566928330808\n",
      "step19 loss: 0.07483596621258184\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-43-a15bd33868d4>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0mtrain_dl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalid_dl\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_ds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalid_ds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopt\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss_func\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_dl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalid_dl\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m<ipython-input-42-7b7192aaa6b4>\u001b[0m in \u001b[0;36mfit\u001b[0;34m(steps, model, loss_func, opt, train_dl, valid_dl)\u001b[0m\n\u001b[1;32m      7\u001b[0m         \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      8\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mxb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myb\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtrain_dl\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m             \u001b[0mloss_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss_func\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopt\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     10\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     11\u001b[0m         \u001b[0;31m#验证模式\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-42-7b7192aaa6b4>\u001b[0m in \u001b[0;36mloss_batch\u001b[0;34m(model, loss_func, xb, yb, opt)\u001b[0m\n\u001b[1;32m     23\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     24\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mloss_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss_func\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopt\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m     \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mloss_func\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mxb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     26\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     27\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mopt\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1192\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1193\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1195\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1196\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-35-e8dda8fcc57c>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m     10\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     11\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhidden1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     13\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdropout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     14\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhidden2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1192\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1193\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1195\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1196\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/linear.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m    112\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    113\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 114\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    115\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    116\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mextra_repr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Start training: rebuild both loaders, create a fresh model/optimizer pair,\n",
    "# and call fit with 100 as the number of steps.\n",
    "\n",
    "train_dl, valid_dl = get_data(train_ds, valid_ds, bs)\n",
    "model, opt = get_model()\n",
    "fit(100, model, loss_func, opt, train_dl, valid_dl)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[(1, 4), (2, 5), (3, 6)]\n",
      "(1, 2, 3)\n",
      "(4, 5, 6)\n"
     ]
    }
   ],
   "source": [
    "a = [1, 2, 3]\n",
    "b = [4, 5, 6]\n",
    "\n",
    "# zip pairs the i-th element of a with the i-th element of b.\n",
    "zipped = zip(a, b)\n",
    "print(list(zipped))\n",
    "\n",
    "# Unpacking the pairs back into zip undoes the pairing: one tuple per original list.\n",
    "a2, b2 = zip(*zip(a, b))\n",
    "print(a2, b2, sep='\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# epoch  整个数据集训练一次\n",
    "# batch  一次迭代\n",
    "\n",
    "模型训练 得到最好的一次参数组合"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute the classification accuracy on the validation set.\n",
    "\n",
    "correct = 0\n",
    "total = 0\n",
    "\n",
    "# Evaluation mode turns dropout off. fit() leaves the model in training mode if it is\n",
    "# interrupted in the middle of a training pass, so set the mode explicitly here.\n",
    "model.eval()\n",
    "with torch.no_grad():  # gradients are not needed just to count correct predictions\n",
    "    for xb, yb in valid_dl:\n",
    "        outputs = model(xb)\n",
    "        # The index of the largest output along dim 1 is taken as the predicted class.\n",
    "        _, predicted = torch.max(outputs, 1)\n",
    "        total += yb.size(0)  # was yb_size(0), which raised NameError\n",
    "        correct += (predicted == yb).sum().item()\n",
    "\n",
    "# Report the accuracy in percent once, over the whole validation set.\n",
    "# Skipped when the loader yielded no samples, to avoid dividing by zero.\n",
    "if total:\n",
    "    print(100*correct/total)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
